/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <gtest/gtest.h>


#define protected public
#define private   public

//#include "graph_optimizer/graph_fusion/fusion_pass_manager/builtin_pass/logsoftmaxgrad_fusion_pass.h"
#include "graph/graph.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/attr_utils.h"
#include "graph/types.h"
#include "graph/ge_tensor.h"
#include "graph/op_desc.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "common/configuration.h"
#include "graph/compute_graph.h"
#include "common/util/op_info_util.h"
#include "nn_norm_ops.h"
using namespace std;

// Op type strings used when building the test graph below.
// Both the characters and the pointers are const so these file-scope names
// cannot be accidentally re-pointed at another string.
static const char *const SUB = "Sub";
static const char *const MUL = "Mul";
static const char *const EXP = "Exp";
static const char *const TF_MATMUL = "Matmul";
static const char *const ADD = "Add";

/**
 * Test fixture that builds a small compute graph (add / sum / mul / sub / exp /
 * constant / matmul) used as input for the LogSoftmaxGrad fusion tests.
 */
class UTEST_omg_optimizer_fusion_logsoftmaxgrad_fusion_unittest : public testing::Test {
protected:
    // No per-test setup or teardown is required; every test builds its own graph.
    void SetUp()
    {
    }
    void TearDown()
    {
    }

protected:
    /**
     * Creates an op description and adds it to @p graph as a new node.
     *
     * Every input and output tensor description uses shape {1, 2, 3} and
     * format NHWC; only the data types differ.
     *
     * @param graph          graph that receives the new node
     * @param name           node name
     * @param type           op type string
     * @param in_anchor_num  number of input descriptions to add
     * @param out_anchor_num number of output descriptions to add
     * @param input_type     data type of every input description
     * @param output_type    data type of every output description
     * @return the node returned by ge::ComputeGraph::AddNode
     */
    ge::NodePtr AddNode(const ge::ComputeGraphPtr &graph,
                        const string &name,
                        const string &type,
                        unsigned int in_anchor_num,
                        unsigned int out_anchor_num,
                        ge::DataType input_type,
                        ge::DataType output_type)
    {
        ge::GeTensorDesc input_tensor_desc(ge::GeShape({1, 2, 3}), ge::FORMAT_NHWC, input_type);
        ge::GeTensorDesc output_tensor_desc(ge::GeShape({1, 2, 3}), ge::FORMAT_NHWC, output_type);
        ge::OpDescPtr op_desc = make_shared<ge::OpDesc>(name, type);
        for (unsigned int i = 0; i < in_anchor_num; ++i) {
            op_desc->AddInputDesc(input_tensor_desc);
        }
        for (unsigned int i = 0; i < out_anchor_num; ++i) {
            op_desc->AddOutputDesc(output_tensor_desc);
        }
        return graph->AddNode(op_desc);
    }

    /*
     * Topology built by CreateGraph() (arrows follow the data flow):
     *
     *              add
     *         out1 /  \ out0
     *             v    v
     *           sum    sub(in0) ----> matmul
     *             |     ^
     *             v     | (in1)
     *        (in0)mul --+
     *             ^ (in1)
     *             |
     *   constant -> exp
     */
    /**
     * Builds the graph drawn above and returns it.
     *
     * The sum node gets its reduction axes through the "axes" list attribute
     * ({-1}); no constant input is wired into it.
     */
    ge::ComputeGraphPtr CreateGraph()
    {
        // create compute graph
        ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("test_logsoftmaxgrad_fusion_graph");
        // create mul node
        ge::NodePtr mul_node = AddNode(graph, "mul", MUL, 2, 1, ge::DataType::DT_FLOAT16,
                                      ge::DataType::DT_FLOAT16);
        // create sum node
        ge::NodePtr sum_node = AddNode(graph, "sum", "Sum", 1, 1, ge::DataType::DT_FLOAT16, ge::DataType::DT_FLOAT);
        // create sub node
        ge::NodePtr sub_node = AddNode(graph, "/sub", SUB, 2, 2, ge::DataType::DT_FLOAT,
                                      ge::DataType::DT_FLOAT);
        // create exp node
        ge::NodePtr exp_node = AddNode(graph, "exp", EXP, 1, 1, ge::DataType::DT_FLOAT,
                                      ge::DataType::DT_FLOAT);
        // create const node
        ge::NodePtr const_node = AddNode(graph, "constant", fe::CONSTANT, 1, 1, ge::DataType::DT_FLOAT,
                                        ge::DataType::DT_FLOAT);
        // create matmul node
        ge::NodePtr matmul_node = AddNode(graph, "matmul", TF_MATMUL, 1, 2, ge::DataType::DT_FLOAT,
                                         ge::DataType::DT_FLOAT);
        // create add node
        ge::NodePtr add_node = AddNode(graph, "add", ADD, 1, 2, ge::DataType::DT_FLOAT,
                                      ge::DataType::DT_FLOAT);

        // sub output 0 -> matmul input 0
        ge::GraphUtils::AddEdge(sub_node->GetOutDataAnchor(0), matmul_node->GetInDataAnchor(0));
        // mul output 0 -> sub input 1
        ge::GraphUtils::AddEdge(mul_node->GetOutDataAnchor(0), sub_node->GetInDataAnchor(1));
        // add output 0 -> sub input 0
        ge::GraphUtils::AddEdge(add_node->GetOutDataAnchor(0), sub_node->GetInDataAnchor(0));
        // add output 1 -> sum input 0
        ge::GraphUtils::AddEdge(add_node->GetOutDataAnchor(1), sum_node->GetInDataAnchor(0));
        // sum output 0 -> mul input 0
        ge::GraphUtils::AddEdge(sum_node->GetOutDataAnchor(0), mul_node->GetInDataAnchor(0));
        // exp output 0 -> mul input 1
        ge::GraphUtils::AddEdge(exp_node->GetOutDataAnchor(0), mul_node->GetInDataAnchor(1));
        // const output 0 -> exp input 0
        ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), exp_node->GetInDataAnchor(0));

        // The sum node is created with a single input, so the constant is not
        // connected to it (the edge below stays disabled); the reduction axes
        // are stored as an attribute instead.
        // ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), sum_node->GetInDataAnchor(1));
        vector<int32_t> axis = {-1};
        ge::AttrUtils::SetListInt(sum_node->GetOpDesc(), "axes", axis);
        return graph;
    }
};

//TEST_F(UTEST_omg_optimizer_fusion_logsoftmaxgrad_fusion_unittest, fusion_success_first)
//{
//    ge::ComputeGraphPtr graph = CreateGraph();
//    fe::LogSoftmaxGradFusionPass pass;
//    fe::Status status = pass.Run(*graph);
//    EXPECT_EQ(fe::SUCCESS, status);
//    size_t node_num = graph->GetDirectNode().size();
//    EXPECT_EQ(node_num, 4);
//    vector<int32_t> axis;
//    if (graph->FindNode("/LogSoftmaxGrad0")) {
//        ge::AttrUtils::GetListInt(graph->FindNode("/LogSoftmaxGrad0")->GetOpDesc(), "axis", axis);
//    }
//    EXPECT_EQ(axis[0], -1);
//
//}
//
//TEST_F(UTEST_omg_optimizer_fusion_logsoftmaxgrad_fusion_unittest, fusion_success_second)
//{
//    ge::ComputeGraphPtr graph = CreateGraph();
//    ge::NodePtr sum_node = graph->FindNode("sum");
//    auto in_data_anchor = sum_node->GetInDataAnchor(0);
//    auto pre_out_data_anchor = in_data_anchor->GetPeerOutAnchor();
//    ge::GraphUtils::RemoveEdge(pre_out_data_anchor, in_data_anchor);
//    fe::LogSoftmaxGradFusionPass pass;
//    fe::Status status = pass.Run(*graph);
//    size_t node_num = graph->GetDirectNode().size();
//    EXPECT_EQ(node_num, 7);
//
//}
//
//TEST_F(UTEST_omg_optimizer_fusion_logsoftmaxgrad_fusion_unittest, fusion_success_third)
//{
//    ge::ComputeGraphPtr graph = CreateGraph();
//    ge::NodePtr sub_node = graph->FindNode("/sub");
//    auto in_data_anchor = sub_node->GetInDataAnchor(0);
//    auto pre_out_data_anchor = in_data_anchor->GetPeerOutAnchor();
//    ge::GraphUtils::RemoveEdge(pre_out_data_anchor, in_data_anchor);
//    ge::NodePtr exp_node = graph->FindNode("exp");
//    ge::GraphUtils::AddEdge(exp_node->GetOutDataAnchor(0), in_data_anchor);
//    fe::LogSoftmaxGradFusionPass pass;
//    fe::Status status = pass.Run(*graph);
//    size_t node_num = graph->GetDirectNode().size();
//    EXPECT_EQ(node_num, 7);
//}

